﻿#include "GridUtils.cginc"
#include "CollisionMaterial.cginc"
#include "ContactHandling.cginc"
#include "ColliderDefinitions.cginc"
#include "Rigidbody.cginc"
#include "Bounds.cginc"
#include "Simplex.cginc"
#include "SolverParameters.cginc"
#include "AtomicDeltas.cginc"
#include "Phases.cginc"
#include "QueryDefinitions.cginc"

#define MAX_RESULTS_PER_SIMPLEX 32

#pragma kernel Clear
#pragma kernel BuildUnsortedList
#pragma kernel FindPopulatedLevels
#pragma kernel SortList
#pragma kernel BuildContactList
#pragma kernel PrefixSumColliderCounts
#pragma kernel SortContactPairs

StructuredBuffer<float4> positions;
StructuredBuffer<quaternion> orientations;
StructuredBuffer<float4> principalRadii;

StructuredBuffer<int> activeParticles;
StructuredBuffer<int> simplices;
StructuredBuffer<int> filters;    
RWStructuredBuffer<aabb> simplexBounds; // bounding box of each simplex. 

StructuredBuffer<transform> transforms;
StructuredBuffer<queryShape> shapes;
RWStructuredBuffer<uint> sortedColliderIndices;

RWStructuredBuffer<uint> colliderTypeCounts;
RWStructuredBuffer<uint> contactOffsetsPerType;
RWStructuredBuffer<uint2> unsortedContactPairs;

RWStructuredBuffer<uint> cellIndices;
RWStructuredBuffer<uint> cellOffsets;

RWStructuredBuffer<uint> cellCounts;
RWStructuredBuffer<uint> offsetInCells;

RWStructuredBuffer<contact> contacts;
RWStructuredBuffer<uint2> contactPairs;
RWStructuredBuffer<uint> dispatchBuffer;

StructuredBuffer<transform> solverToWorld;
StructuredBuffer<transform> worldToSolver;

uint maxResults;
uint queryCount;    // number of query shapes inserted into the grid.
uint cellsPerShape; // max number of cells a query shape can be inserted into. Typically this is 8.
int shapeTypeCount; // number of different query shapes, ie: box, sphere, ray, etc. 

// Builds the bounding box of a query shape, expressed in the same space as shape.center and
// grown by the shape's contact offset plus its maximum query distance.
// Shape types other than sphere, box and ray leave the box at its initial min_/max_ values.
aabb CalculateShapeAABB(in queryShape shape)
{
    // total margin added around the shape:
    float margin = shape.contactOffset + shape.maxDistance;

    aabb result;
    result.min_ = FLT_MAX;
    // NOTE(review): if FLT_MIN is the smallest positive float (as in C's <float.h>) rather than
    // the most negative one, this is not a true "empty" upper bound - confirm against its definition.
    result.max_ = FLT_MIN;

    if (shape.type == SPHERE_QUERY)
    {
        // sphere: size.x is used as the radius.
        result.FromParticle(shape.center, shape.size.x + margin);
    }
    else if (shape.type == BOX_QUERY)
    {
        // box: size holds full extents, so go half a size to each side of the center.
        result.FromEdge(shape.center - shape.size * 0.5f, shape.center + shape.size * 0.5f, margin);
    }
    else if (shape.type == RAY_QUERY)
    {
        // ray: segment from center to center + size.
        result.FromEdge(shape.center, shape.center + shape.size, margin);
    }

    return result;
}

// Resets the grid data before it is rebuilt:
//  - thread 0 zeroes every entry of levelPopulation,
//  - threads below maxCells reset the offset and count of their cell,
//  - threads below queryCount invalidate all cell slots reserved for their shape.
[numthreads(128, 1, 1)]
void Clear (uint3 id : SV_DispatchThreadID) 
{
    uint index = id.x;

    // a single thread takes care of the level population list:
    if (index == 0)
    {
        for (int level = 0; level <= GRID_LEVELS; ++level)
            levelPopulation[level] = 0;
    }

    // offsets start out as INVALID, so that an atomic minimum can later be used to calculate them.
    if (index < maxCells)
    {
        cellOffsets[index] = INVALID;
        cellCounts[index] = 0;
    }

    // each shape owns cellsPerShape consecutive slots in cellIndices; mark them all as unused.
    if (index < queryCount)
    {
        uint firstSlot = index * cellsPerShape;
        for (uint slot = 0; slot < cellsPerShape; ++slot)
            cellIndices[firstSlot + slot] = INVALID;
    }
}

// Inserts every query shape into the hashed grid. One thread per shape:
//  - computes the shape's bounds in solver space and picks a grid level from their largest side,
//  - for each cell overlapped by the bounds, stores the hashed cell index in the shape's slot
//    block of cellIndices, bumps that cell's counter and records the shape's offset inside the cell,
//  - increments the population counter of the chosen level.
[numthreads(128, 1, 1)]
void BuildUnsortedList (uint3 id : SV_DispatchThreadID) 
{
    unsigned int i = id.x;
    if (i >= queryCount) return;

    // get bounds in solver space:
    aabb bounds = CalculateShapeAABB(shapes[i]).Transformed(worldToSolver[0].Multiply(transforms[i])); 

    // calculate bounds size, grid level and cell size:
    float4 size = bounds.max_ - bounds.min_;
    float maxSize = max(max (size.x, size.y), size.z);
    int level = GridLevelForSize(maxSize);
    float cellSize = CellSizeOfLevel(level);
    
    // calculate max and min cell coordinates (force 4th component to zero, might not be after expanding)
    int4 minCell = floor(bounds.min_ / cellSize);
    int4 maxCell = floor(bounds.max_ / cellSize);
    minCell[3] = 0;
    maxCell[3] = 0;

    int4 cellSpan = maxCell - minCell;

    // The flat index below lays the shape's cells out as a 2x2x2 block (x + y*2 + z*4).
    // A span larger than one cell on any axis would alias slots or spill into the slots of
    // other shapes, so clamp it. Spans of 0 or 1 are unaffected.
    cellSpan = min(cellSpan, int4(1, 1, 1, 0));
   
    // insert collider in cells:
    for (int x = 0; x <= cellSpan[0]; ++x)
    {
        for (int y = 0; y <= cellSpan[1]; ++y)
        {
            for (int z = 0; z <= cellSpan[2]; ++z)
            {
                // the grid level goes in the 4th component of the hashed coordinates.
                int cellIndex = GridHash(minCell + int4(x, y, z, level));
                
                // calculate flat index of this cell into arrays:
                int k = x + y*2 + z*4 + i*cellsPerShape;

                cellIndices[k] = cellIndex;

                // the counter's previous value becomes this entry's offset within the cell.
                InterlockedAdd(cellCounts[cellIndex],1,offsetInCells[k]);
            }
        }
    }

    // atomically increase this level's population by one:
    InterlockedAdd(levelPopulation[1 + level],1);
}

// Scatters every used (shape, cell) slot to its final position in sortedColliderIndices.
// One thread per slot; slots still marked INVALID are skipped.
[numthreads(128, 1, 1)]
void SortList (uint3 id : SV_DispatchThreadID) 
{
    uint slot = id.x;
    if (slot >= queryCount * cellsPerShape) return;

    uint cell = cellIndices[slot];
    if (cell == INVALID) return;

    // final position = start of the cell + offset of this entry inside the cell.
    // Note that the slot index is stored, not the shape index.
    sortedColliderIndices[cellOffsets[cell] + offsetInCells[slot]] = slot;
}

// Finds, for each simplex, the query shapes whose grid cells overlap the simplex bounds, and
// enqueues a (simplex, shape) pair for every shape that passes collision filtering and an
// AABB overlap test. One thread per simplex.
//
// At most MAX_RESULTS_PER_SIMPLEX candidates are gathered per simplex; further ones are dropped.
// dispatchBuffer[7] is used as the running pair counter; pairs whose counter value is
// at or above maxResults are counted but not stored.
[numthreads(128, 1, 1)]
void BuildContactList (uint3 id : SV_DispatchThreadID) 
{
    unsigned int threadIndex = id.x;

    if (threadIndex >= pointCount + edgeCount + triangleCount) return;

    int candidateCount = 0;
    uint candidates[MAX_RESULTS_PER_SIMPLEX];
    
    int simplexSize;
    int simplexStart = GetSimplexStartAndSize(threadIndex, simplexSize);
    
    aabb b = simplexBounds[threadIndex];

    // max size of the particle bounds in cells:
    int4 maxSize = int4(10,10,10,10);

    // build a list of candidate colliders:
    // levelPopulation[0] is read as the number of level entries that follow it.
    for (uint m = 1; m <= levelPopulation[0]; ++m)
    {
        uint l = levelPopulation[m];
        float cellSize = CellSizeOfLevel(l);

        // cells covered by the simplex bounds at this level, limited to maxSize cells per axis:
        int4 minCell = floor(b.min_ / cellSize);
        int4 maxCell = floor(b.max_ / cellSize);
        maxCell = minCell + min(maxCell - minCell, maxSize);

        for (int x = minCell[0]; x <= maxCell[0]; ++x)
        {
            for (int y = minCell[1]; y <= maxCell[1]; ++y)
            {
                for (int z = minCell[2]; z <= maxCell[2]; ++z)
                {
                    uint flatCellIndex = GridHash(int4(x,y,z,l));
                    uint cellStart = cellOffsets[flatCellIndex];
                    uint cellCount = cellCounts[flatCellIndex];

                    // iterate through queries in the neighbour cell
                    for (uint n = cellStart; n < cellStart + cellCount; ++n)
                    {
                        // sorted entries are slot indices; dividing by cellsPerShape recovers the shape index.
                        if (candidateCount < MAX_RESULTS_PER_SIMPLEX)
                            candidates[candidateCount++] = sortedColliderIndices[n] / cellsPerShape;
                    }
                   
                }
            }
        }
    }
    
    //evaluate candidates and create contacts: 
    if (candidateCount > 0)
    {
        // insert sort:
        for (int k = 1; k < candidateCount; ++k)
        {
            uint key = candidates[k];
            int j = k - 1;

            while (j >= 0 && candidates[j] > key)
                candidates[j + 1] = candidates[j--];

            candidates[j + 1] = key;
        }

        // make sure each candidate only shows up once in the list:
        // (the list is sorted, so duplicates are adjacent and can be compacted in place)
        int first = 0, contactCount = 0;
        while(++first != candidateCount)
        {
            if (candidates[contactCount] != candidates[first])
                candidates[++contactCount] = candidates[first];
        }
        contactCount++;

        // append contacts:
        for (int i = 0; i < contactCount; i++)
        {
            int c = candidates[i];
           
            // get shape bounds in solver space:
            aabb colliderBoundsSS = CalculateShapeAABB(shapes[c]).Transformed(worldToSolver[0].Multiply(transforms[c])); 

            // check if any simplex particle and the collider should collide:
            // each side's category must be present in the other side's mask.
            bool shouldCollide = false;
            int colliderCategory = shapes[c].filter & CategoryMask;
            int colliderMask = (shapes[c].filter & MaskMask) >> 16;
            for (int j = 0; j < simplexSize; ++j)
            {
                int simplexCategory = filters[simplices[simplexStart + j]] & CategoryMask;
                int simplexMask = (filters[simplices[simplexStart + j]] & MaskMask) >> 16;
                shouldCollide = shouldCollide || ((simplexCategory & colliderMask) != 0 && (simplexMask & colliderCategory) != 0);
            }

            if (shouldCollide && b.IntersectsAabb(colliderBoundsSS))
            {
                // reserve a slot in the pair list; count receives the value before the increment.
                uint count;
                InterlockedAdd(dispatchBuffer[7], 1, count);

                // technically incorrect, as number of pairs != number of contacts but
                // we will ignore either excess pairs or contacts.
                if (count < maxResults) 
                {
                    // increment the amount of contacts for this shape type:
                    InterlockedAdd(colliderTypeCounts[shapes[c].type],1);

                    // enqueue a new contact pair:
                    unsortedContactPairs[count] = uint2(threadIndex,c);

                    // keep the largest group count seen so far (128 pairs per group);
                    // presumably consumed by a later indirect dispatch - confirm against the caller.
                    InterlockedMax(dispatchBuffer[4],(count + 1) / 128 + 1);
                }
            }
        }
    }   
}

// Single-threaded exclusive prefix sum over the per-type pair counts.
// For each shape type it writes the start offset of the next type into contactOffsetsPerType,
// and stores both a group count (pairs / 128 + 1) and the raw pair count in a 4-uint
// record of dispatchBuffer starting at index 8.
[numthreads(1, 1, 1)]
void PrefixSumColliderCounts (uint3 id : SV_DispatchThreadID) 
{
    uint runningOffset = 0;
    contactOffsetsPerType[0] = runningOffset;

    for (int shapeType = 0; shapeType < shapeTypeCount; ++shapeType)
    {
        uint pairCount = colliderTypeCounts[shapeType];

        // pairs of the next type start right after the pairs of this one.
        runningOffset += pairCount;
        contactOffsetsPerType[shapeType + 1] = runningOffset;

        // write amount of pairs per collider type in the dispatch buffer:
        int recordStart = 8 + shapeType * 4;
        dispatchBuffer[recordStart] = pairCount / 128 + 1;
        dispatchBuffer[recordStart + 3] = pairCount;
    }
}

// Moves each unsorted contact pair into contactPairs, grouped by the type of its query shape.
// One thread per stored pair. The per-type counters are consumed (decremented) in the process.
[numthreads(128, 1, 1)]
void SortContactPairs (uint3 id : SV_DispatchThreadID) 
{
    uint pairIndex = id.x;

    // only pairs that were both counted and actually stored are processed.
    if (pairIndex >= min(dispatchBuffer[7], maxResults)) return;

    uint2 pair = unsortedContactPairs[pairIndex];
    int shapeType = (int)shapes[pair.y].type;

    // decrement amount of pairs for the given collider type, fetching the value it had before:
    uint remaining;
    InterlockedAdd(colliderTypeCounts[shapeType], -1, remaining);

    // the previous counter value (minus one) is this pair's position inside its type's range.
    uint destination = contactOffsetsPerType[shapeType] + remaining - 1;
    contactPairs[destination] = pair;
}



